{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multi channel CNN for sentiment analysis -- shared imports\n",
    "from nltk.corpus import stopwords\n",
    "# FIX: sent_tokenize is called by torchEmbeddings() below but was never imported\n",
    "from nltk.tokenize import sent_tokenize\n",
    "from string import punctuation,digits\n",
    "import pandas as pd\n",
    "import json\n",
    "import numpy as np\n",
    "import re\n",
    "#import word2vecReader as godin_embedding\n",
    "from keras.preprocessing.text import Tokenizer\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "# tqdm.auto supersedes the deprecated tqdm_notebook and picks the right frontend\n",
    "from tqdm.auto import tqdm\n",
    "from keras.utils.vis_utils import plot_model\n",
    "from keras.utils import to_categorical\n",
    "from keras.models import Model\n",
    "from keras.layers import Input\n",
    "import torch\n",
    "from keras.layers import Dense\n",
    "from keras.layers import Flatten\n",
    "from keras.layers import Dropout\n",
    "from keras.layers import Embedding\n",
    "from keras.layers.convolutional import Conv1D\n",
    "from keras.layers.convolutional import MaxPooling1D\n",
    "from keras.optimizers import Adam\n",
    "from keras.layers.merge import concatenate\n",
    "import keras.backend as K\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "plt.rcParams[\"figure.figsize\"] = [5,5]\n",
    "plt.style.use('seaborn-notebook')\n",
    "from scipy.interpolate import interp1d\n",
    "from sklearn.metrics import mean_squared_error,r2_score\n",
    "#from aspect_specific_prob import get_normalized_sentence_relation_vector\n",
    "from math import sqrt\n",
    "#from gensim.models import KeyedVectors\n",
    "from sklearn.model_selection import KFold\n",
    "import nltk\n",
    "nltk.download('stopwords')\n",
    "nltk.download('punkt')  # sentence tokenizer model required by sent_tokenize\n",
    "import json\n",
    "from embedding_as_service.text.encode import Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Training and Testing Data \n",
    "# NOTE(review): path is relative to the notebook working dir; presumably keys of\n",
    "# this JSON object are the raw text items to embed -- verify against the dataset\n",
    "with open('imageDataset/gossip/testJson.json', 'r') as f:\n",
    "    testData = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = {}  # maps input text -> list of per-sentence model outputs\n",
    "lengths = []  # scratch list for per-document sentence counts (debug only)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'transformers'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[1;32m/media/zhengpeng/epan/program/python/SpotFakePlusAAAI/prepareEmbeddings.ipynb Cell 4\u001b[0m line \u001b[0;36m2\n\u001b[1;32m      <a href='vscode-notebook-cell:/media/zhengpeng/epan/program/python/SpotFakePlusAAAI/prepareEmbeddings.ipynb#W3sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39m# XLNET Model from Hugging Face\u001b[39;00m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/media/zhengpeng/epan/program/python/SpotFakePlusAAAI/prepareEmbeddings.ipynb#W3sZmlsZQ%3D%3D?line=1'>2</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mtransformers\u001b[39;00m \u001b[39mimport\u001b[39;00m XLNetModel, XLNetTokenizer\n\u001b[1;32m      <a href='vscode-notebook-cell:/media/zhengpeng/epan/program/python/SpotFakePlusAAAI/prepareEmbeddings.ipynb#W3sZmlsZQ%3D%3D?line=2'>3</a>\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mos\u001b[39;00m\n\u001b[1;32m      <a href='vscode-notebook-cell:/media/zhengpeng/epan/program/python/SpotFakePlusAAAI/prepareEmbeddings.ipynb#W3sZmlsZQ%3D%3D?line=3'>4</a>\u001b[0m os\u001b[39m.\u001b[39menviron[\u001b[39m\"\u001b[39m\u001b[39mCUDA_DEVICE_ORDER\u001b[39m\u001b[39m\"\u001b[39m]\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mPCI_BUS_ID\u001b[39m\u001b[39m\"\u001b[39m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'transformers'"
     ]
    }
   ],
   "source": [
    "# XLNET Model from Hugging Face\n",
    "from transformers import XLNetModel, XLNetTokenizer\n",
    "import os\n",
    "# pin device numbering to the PCI bus so CUDA_VISIBLE_DEVICES is unambiguous\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
    "import threading\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "tokenizer = XLNetTokenizer.from_pretrained('xlnet/xlnet-base-cased')\n",
    "model = XLNetModel.from_pretrained('xlnet/xlnet-base-cased')\n",
    "model.to('cuda')\n",
    "# FIX: inference only -- eval() disables dropout so embeddings are deterministic\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate sentence-level XLNet embeddings for a paragraph\n",
    "# FIX: sent_tokenize was used here but never imported (NameError on first call)\n",
    "from nltk.tokenize import sent_tokenize\n",
    "\n",
    "def torchEmbeddings(data):\n",
    "    \"\"\"Split `data` into sentences and return one model output per sentence.\n",
    "\n",
    "    Each list element is whatever `model(input_ids)` returns (a tuple on older\n",
    "    transformers versions, a ModelOutput on newer ones) -- callers unpack it.\n",
    "    Assumes `tokenizer` and `model` were created by the XLNet cell above and\n",
    "    that the model lives on 'cuda'.\n",
    "    \"\"\"\n",
    "    embed = []\n",
    "    for sent in sent_tokenize(data):\n",
    "        # unsqueeze adds the batch dimension; input must be on the model's device\n",
    "        input_ids = torch.tensor(tokenizer.encode(sent)).unsqueeze(0)\n",
    "        input_ids = input_ids.to('cuda')\n",
    "        embed.append(model(input_ids))\n",
    "    return embed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# quick sanity check on the accumulated results\n",
    "# NOTE(review): this dumps the entire dict; very large output once populated\n",
    "embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dump Embeddings into a pickle file\n",
    "import pickle\n",
    "# NOTE(review): despite the 'Train' filename, `embeddings` in this notebook is\n",
    "# built from testJson.json -- confirm which split this file actually holds\n",
    "with open('gossip/finalTrainEmbeddings.pkl', 'wb') as f:\n",
    "    pickle.dump(embeddings, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Some optimizations for large Embedding Files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the (deduplicated) items into two halves for chunked processing.\n",
    "# FIX: sorted() makes the split deterministic across sessions; bare list(set())\n",
    "# order depends on the per-process string hash seed, so the halves would differ\n",
    "# on every run and cached pickle files would not line up\n",
    "queryList = [list(i) for i in np.array_split(sorted(set(testData)), 2)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exceptions = []  # collects inputs whose embedding raised an error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "# NOTE(review): disabled one-off dump kept for reference; the bare string below\n",
    "# is the cell's displayed value, not executed code\n",
    "\"\"\"with open('gossip/finalTrainEmbeddings1.pkl', 'wb') as f:\n",
    "    pickle.dump(embeddings, f)\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the second half of the query list; no_grad avoids autograd bookkeeping\n",
    "with torch.no_grad():\n",
    "    embeddings = {}\n",
    "    for i in tqdm(queryList[1]):\n",
    "        try:\n",
    "            embeddings[i] = torchEmbeddings(i)\n",
    "        # FIX: narrowed from bare `except:` so KeyboardInterrupt/SystemExit\n",
    "        # still abort the loop; failed inputs are recorded for later inspection\n",
    "        except Exception:\n",
    "            exceptions.append(i)\n",
    "    with open('gossip/finalTestEmbeddings2.pkl', 'wb') as f:\n",
    "        pickle.dump(embeddings, f)\n",
    "            #print(len(sent_tokenize(i)))\n",
    "            #lengths.append(len(sent_tokenize(i)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch dataset: rebind testData to the politifact split\n",
    "with open('imageDataset/politi/testJson.json', 'r') as f:\n",
    "    testData = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed every politifact test item under no_grad (inference only)\n",
    "testEmbeddings = {}\n",
    "with torch.no_grad():\n",
    "    for i in tqdm(testData):\n",
    "        # NOTE(review): unlike the gossip loop there is no try/except here, so a\n",
    "        # single tokenizer/model failure aborts the entire pass -- confirm intended\n",
    "        testEmbeddings[i] = torchEmbeddings(i)\n",
    "        #lengths.append(len(sent_tokenize(i)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For politifact\n",
    "import pickle\n",
    "# NOTE(review): `embeddings` still holds the gossip-half results unless the\n",
    "# earlier loop was re-run after switching datasets -- verify before trusting\n",
    "# the 'Train' file written here\n",
    "with open('politifact/finalTrainEmbeddings.pkl', 'wb') as f:\n",
    "    pickle.dump(embeddings, f)\n",
    "with open('politifact/finalTestEmbeddings.pkl', 'wb') as f:\n",
    "    pickle.dump(testEmbeddings, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
